import warnings
import logging
import os
from copy import deepcopy
import tqdm
# 抑制警告
warnings.filterwarnings('ignore')
logging.getLogger('omnigibson').setLevel(logging.ERROR)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'  # 使用镜像站点

# 如果需要更细致的控制，可以针对特定警告
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', message='.*particleRemover.*')
warnings.filterwarnings('ignore', message='.*ClothStateMixin.*')

ALGO_NAME = 'PolicyDecorator-DiffusionPolicy'

import os
import argparse
import random
from distutils.util import strtobool

os.environ["OMP_NUM_THREADS"] = "1"

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from datetime import datetime
from collections import defaultdict



import os
import yaml
import torch
import numpy as np


import os
import torch
import numpy as np
import omnigibson as og
import yaml
import time
import os
import json
import numpy as np
import torch as th
import yaml
from scipy.spatial.transform import Rotation as R
from multiprocessing import shared_memory, Queue, Event
import omnigibson as og
import omnigibson.lazy as lazy
from omnigibson.macros import gm

import torch.nn.functional as F
import sys
import os
import yaml
import time
from multiprocessing.managers import SharedMemoryManager
import numpy as np
import torch
import dill
import hydra
from omegaconf import OmegaConf
import sys
sys.path.append("/home/admin01/project") 
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
import h5py
import sys
sys.path.append("/home/admin01/project/DexDiffusionPolicy")
import yaml  # 用于读取配置文件
import dill  # 用于加载模型
import hydra  # 用于加载diffusion policy
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from Benchmark_results.evaluate.env_aug_iter import EnvAugmentor
from omnigibson.envs import DataCollectionWrapper 
from omnigibson.tasks.task_base import BaseTask
# ---------------------------------------------------------------------------
# Evaluation configuration. The "/path/to/..." values are placeholders that
# must be edited before the script is run.
# ---------------------------------------------------------------------------
# YAML file with the environment-augmentation configs (read in get_domain_info
# and evaluate; its 'aug' list is indexed per episode in env_reset).
ENV_AUGMENTOR_CONFIG_PATH = "/path/to/aug_base.yaml"
DEFAULT_OBJ_CONFIG_DIR = "/path/to/d3" 
TEST_OUT_PATH = "/path/to/test_datagen/task_name/datagen" # Data is also recorded during evaluation; this is the directory it is saved to.

# Checkpoints of the base (diffusion) policy and of the residual policy.
base_policy_ckpt = "/path/to/lift_epoch=4650-train_loss=0.002.ckpt"
res_policy_path="/path/to/checkpoints/1040000.pt" 

mix=False # Whether to mix the outputs
RES_SCALE=0.1 # Scale for the phase in which RL takes part
NORMAL_RES_CALE=0  # Scale for the phase in which RL does not take part
# Flags consulted by env_reset: restore the default object state, and (only
# when the first flag is also set) apply the per-episode augmentation config.
random_object_flag=True
random_vision_flag=True

# YAML file whose "sim" section configures the simulator environment.
CONFIG_FILE_PATH = "/path/to/common/configs_env.yaml"
log_save_path = "/path/to/log" # Directory for evaluation .log files
json_save_path = "/path/to/json" # Directory for evaluation .json result files

from datetime import datetime
# Name of the first HDF5 output file, stamped with the script start time.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path_name = f"data_{timestamp}.hdf5"
# Single-element list used as a mutable cell: complete_loop() overwrites
# element 0 with the file name to use for the next data-collection loop.
current_filename = [output_path_name]
# Parsed augmentation config; stays None until evaluate() loads it.
domain_cfg = None

class MyTask(BaseTask):
    """Holder for the custom per-step reward of the two-handed lift task.

    Within the part of the file reviewed here the class is not instantiated;
    its ``_step_reward`` is assigned onto ``BaseTask`` right after the class
    body, so the function runs for whatever task object the environment uses.
    """
    

    def _step_reward(self, env, action, info=None):
        """Compute the staged grasp-then-lift reward for the object ``ball0``.

        The reward is the sum of
          * a grasp term per hand, ``exp(-8 * max(d - 0.08, 0))`` where ``d``
            is the mean distance from three finger links to that hand's grasp
            point (equals 1.0 once ``d <= 0.08``);
          * a synchronisation term ``4 * exp(-5 * max(d_l + d_r - 0.2, 0))``;
          * only while the stage flag is 1: a lift term in ``[0, 10]`` that is
            proportional to the height gained over ``obj_start_pos`` (full
            value at 0.15) plus a non-positive tilt penalty.

        Side effects: writes the module-level ``reward_stage_flag`` and reads
        the module-level ``obj_start_pos`` (assigned elsewhere in this file,
        e.g. in ``env_reset``).

        Args:
            env: Environment providing ``robots[0]`` and the scene registry.
            action: Unused by this implementation.
            info: Optional dict to extend; a new dict is created when None.

        Returns:
            tuple: ``(total_reward, info)`` where ``info["reward_breakdown"]``
            holds the individual terms and the current stage.
        """
        global reward_stage_flag
        # global obj_start_pos_left,obj_start_pos_right,obj_start_pos_front,obj_start_pos_back
        r1,r2=0,0
        robot = env.robots[0]
        obj = env.scene.object_registry("name", "ball0")
        obj_pos = torch.tensor(obj.get_position())
        obj_quat=torch.tensor(obj.get_orientation())
        # Compute the object's tilt: convert its quaternion to Euler angles.
        from scipy.spatial.transform import Rotation as R
        # Original author's note: the quaternion is already in (x, y, z, w)
        # order, so it is passed through unchanged.
        r = R.from_quat([obj_quat[0], obj_quat[1], obj_quat[2], obj_quat[3]])
        euler_angles = r.as_euler('xyz', degrees=True)
        x_tilt = abs(euler_angles[0])
        y_tilt = abs(euler_angles[1])
        max_tilt_angle = max(x_tilt, y_tilt)# Larger of the absolute x / y Euler angles, in degrees
        # board=env.scene.object_registry("name", "ball1")
        # initial_height = board.get_position()[2]
        # Tilt-penalty threshold and cap.
        tilt_threshold = 30.0  # degrees; no penalty at or below this angle
        max_tilt_penalty = 5.0  # cap on the penalty magnitude
        # Tilt penalty.
        if max_tilt_angle > tilt_threshold:
            # Grows linearly with the excess angle. Note that the excess is
            # divided by max_tilt_penalty, so the cap is reached 25 degrees
            # above the threshold.
            tilt_penalty = -min(max_tilt_penalty, (max_tilt_angle - tilt_threshold) /max_tilt_penalty)
        else:
            tilt_penalty = 0.0
        bbox_center_in_world, bbox_quat_in_world, bbox_extent_in_base_frame, bbox_center_in_desired_frame = obj.get_base_aligned_bbox(visual=False)
        # Grasp points: object position shifted along y by half the bbox
        # y-extent plus 0.015 (one point per side) and lowered by 0.015 in z.
        grasp_point_left=obj_pos.clone()
        grasp_point_left[1]=grasp_point_left[1]+bbox_extent_in_base_frame[1]/2 +0.015
        grasp_point_left[2]=grasp_point_left[2]-0.015 #-0.015
        grasp_point_right=obj_pos.clone()
        grasp_point_right[1]=grasp_point_right[1]-bbox_extent_in_base_frame[1]/2 -0.015
        grasp_point_right[2]=grasp_point_right[2]-0.015
        ball_left_pos=grasp_point_left
        ball_right_pos=grasp_point_right
        
        # Right hand: distances from the right grasp point to three finger
        # links. Finger names follow the original author's labels; the thumb
        # position is fetched but only index / middle / ring are used.
        right_hand_thumb_pos = robot._links["hand2_link_1_3"].get_position() # right thumb (unused)
        right_hand_index_pos = robot._links["hand2_link_2_2"].get_position() # right index finger
        right_hand_middle_pos = robot._links["hand2_link_3_2"].get_position() # right middle finger
        right_hand_ring_pos = robot._links["hand2_link_4_2"].get_position() # right ring finger
        right_fingers_pos = [right_hand_index_pos, right_hand_middle_pos,right_hand_ring_pos]
        right_current_fingers = [torch.norm(ball_right_pos - finger_pos, p=2, dim=-1) for finger_pos in right_fingers_pos]
        right_current_finger_dist = sum(right_current_fingers)
        right_avg_distance= sum(right_current_fingers) / len(right_current_fingers)
        
        # Left hand: same computation against the left grasp point (again the
        # thumb position is fetched but not used).
        left_hand_thumb_pos = robot._links["hand1_link_1_3"].get_position()
        left_hand_index_pos = robot._links["hand1_link_2_2"].get_position()
        left_hand_middle_pos = robot._links["hand1_link_3_2"].get_position()
        left_hand_ring_pos = robot._links["hand1_link_4_2"].get_position() # left ring finger
        left_fingers_pos = [left_hand_index_pos, left_hand_middle_pos,left_hand_ring_pos]
        left_current_fingers = [torch.norm(ball_left_pos - finger_pos, p=2, dim=-1) for finger_pos in left_fingers_pos]
        left_current_finger_dist = sum(left_current_fingers)
        left_avg_distance= sum(left_current_fingers) / len(left_current_fingers)
        
        # print("right_current_finger_dist: ",right_current_finger_dist)
        # print("left_current_finger_dist: ",left_current_finger_dist)
        # print(f"left avg distance: {left_avg_distance:.4f}, right avg distance: {right_avg_distance:.4f}")
        # Grasp rewards: 1.0 inside the 0.08 dead zone, exponential decay outside.
        right_grasp_reward = 1.0 * torch.exp(- 8 * torch.clamp(right_avg_distance - 0.08, 0, None))
        left_grasp_reward = 1.0 * torch.exp(- 8 * torch.clamp(left_avg_distance - 0.08, 0, None))
        # print("right_grasp_reward: ",right_grasp_reward)
        # print("left_grasp_reward: ",left_grasp_reward)
        # Two-hand synchronisation reward based on the summed mean distances.
        sum_sync_distance=right_avg_distance+left_avg_distance
        hand_sync_reward= 4 * torch.exp(- 5 * torch.clamp(sum_sync_distance - 0.2, 0, None))
        # print("hand_sync_reward: ",hand_sync_reward)

        # Earlier lift-reward formulation, kept for reference:
        # def lift_reward_func(obj_pos):
        #     target_lift_height = 0.2
        #     target_z = initial_height + target_lift_height
        #     height_diff = torch.abs(target_z - obj_pos[2]) # remaining height still to be lifted
        #     lift_reward =  10 * torch.clamp(0.2 - height_diff, -0.01, None)
        #     return lift_reward,height_diff
        
        #### lift: fraction of the 0.15 target height gained over the start
        #### position, clipped to [0, 1] and scaled to [0, 10].
        target_lift_height = 0.15
        r_lift = (obj_pos[2] - obj_start_pos[2]) / (target_lift_height)
        r_lift = 10 *np.clip(r_lift, 0.0, 1.0)
        
        # lift_left_reward,lift_left_height_diff = lift_reward_func(grasp_point_left)
        # lift_right_reward,lift_right_height_diff = lift_reward_func(grasp_point_right)
        # # synchronisation reward
        # sum_sync_height_diff = lift_left_height_diff+lift_right_height_diff
        # lift_sync_reward = 400 * torch.clamp(0.4 - sum_sync_height_diff, -0.01, None)
        # print("lift_left_height_diff:",lift_left_height_diff)
        # print("lift_right_height_diff:",lift_right_height_diff)
        # print("lift_left_reward:",lift_left_reward)
        # print("lift_right_reward:",lift_right_reward)
        # print("lift_sync_reward:",lift_sync_reward)
        ######################################################################################
        ############################### staging ######################################
        # Stage 1 (lift terms active) requires all three grasp-related terms
        # to be near their maxima; the flag is recomputed on every step.
        if right_grasp_reward>0.9 and left_grasp_reward>0.9 and hand_sync_reward>3.9:
            reward_stage_flag=1
        else:
            reward_stage_flag=0
        if reward_stage_flag==0:
            r1=left_grasp_reward + right_grasp_reward +hand_sync_reward
            r2=0
        else:
            r1=left_grasp_reward + right_grasp_reward +hand_sync_reward
            r2=  r_lift+tilt_penalty
            
        # Total reward: grasp terms plus (stage 1 only) lift and tilt terms.
        total_reward = r1+r2
        
        # Update the info dict.
        total_info = dict() if info is None else info
        
        if "reward_breakdown" not in total_info:
            total_info["reward_breakdown"] = {}

        # Note: "lift_reward" and "tilt_penalty" are reported even in stage 0,
        # where they do not contribute to total_reward.
        total_info["reward_breakdown"].update({
            "total_reward": total_reward,
            "right_grasp_reward": right_grasp_reward,
            "left_grasp_reward": left_grasp_reward,
            "hand_sync_reward":hand_sync_reward,
            "lift_reward": r_lift,
            "tilt_penalty":float(tilt_penalty),
            "stage":reward_stage_flag,
        })
        
        return total_reward, total_info


# Monkey-patch: replace BaseTask._step_reward with the custom reward above.
BaseTask._step_reward = MyTask._step_reward

class MyEnv(og.Environment):
    """Environment variant whose episodes end when ``ball0`` has been lifted.

    Only ``_post_step`` is customised; it does not call the parent
    implementation.
    """

    def _post_step(self,action):
        """Assemble the step result after the simulator has advanced.

        The ``done`` flag returned by the task is discarded and replaced by a
        success check: the episode is done once ``ball0`` is more than 0.15
        above the module-level ``obj_start_pos`` height. The episode is
        truncated once the step counter reaches 600.

        Args:
            action: Action that was applied this step (forwarded to the task).

        Returns:
            tuple: ``(obs, reward, terminated, truncated, info)``.
        """
        global obj_start_pos

        # Observations for this step.
        obs, obs_info = self.get_obs()

        # Advance the scene-graph builder when one is configured.
        graph_builder = self._scene_graph_builder
        if graph_builder is not None:
            graph_builder.step(self.scene)

        # Reward and info come from the task; its done flag is ignored.
        reward, _, info = self.task.step(self, action)

        # Success criterion: height gained by the object since the reset.
        target = self.scene.object_registry("name", "ball0")
        done = bool(target.get_position()[2] - obj_start_pos[2] > 0.15)

        self._populate_info(info)
        info["obs_info"] = obs_info

        if done and self._automatic_reset:
            # Keep the final observation in info before it is replaced by the
            # result of the reset.
            info["last_observation"] = obs
            obs = self.reset()

        # Success counts as termination; the step budget as truncation.
        terminated = done
        self._current_step += 1
        truncated = self._current_step >= 600
        return obs, reward, terminated, truncated, info


def env_reset(wrapped_env, env_augmentor,index):
    """Reset the wrapped environment and prepare it for episode ``index``.

    Steps, in order: reset the environment; optionally restore the default
    object state and apply augmentation config ``domain_cfg['aug'][index]``
    (the augmentation is only applied when both ``random_object_flag`` and
    ``random_vision_flag`` are set); reset and freeze the robot; step the
    simulator 10 times; record the start position of ``ball0`` in the
    module-level ``obj_start_pos`` and hand it to the task; finally point the
    data-collection wrapper at the current HDF5 file.

    Args:
        wrapped_env: Data-collection wrapper exposing the environment as ``.env``.
        env_augmentor: Augmentor used to restore / randomise the scene.
        index: Index into ``domain_cfg['aug']`` for this episode.
    """
    global obj_start_pos
    global domain_cfg

    env = wrapped_env.env
    env.reset()
    print("reset env ", current_filename[0])

    if random_object_flag:
        env_augmentor.restore2defaultstate(env)
        if random_vision_flag:
            env_augmentor.iterate_env_aug(env, domain_cfg['aug'][index])

    robot = env.robots[0]
    robot.reset()
    robot.keep_still()
    # Let the simulation settle before the start pose is sampled.
    for _ in range(10):
        og.sim.step()
    move_to_initial_position(env)

    target = env.scene.object_registry("name", "ball0")
    obj_start_pos = target.get_position()
    env._task.reset_for_new_task(obj_start_pos)  # start pose used by the reward

    # Label describing the ball0 asset of this episode, if it is configured.
    info = None
    episode_cfg = domain_cfg['aug'][index]
    if "obj_config" in episode_cfg:
        for entry in episode_cfg['obj_config']:
            if "ball0" not in entry["name"]:
                continue
            if 'category' in entry:
                info = f"{entry['category']}_{entry['model']}"
            elif 'usd_path' in entry:
                info = f"{entry['usd_path']}"
        print("info:",info)

    h5_path = os.path.join(TEST_OUT_PATH, current_filename[0])
    wrapped_env.set_h5_file_path(env, h5_path, info)


    # ball = env.scene.object_registry("name", "ball0")
    # obj_start_pos=ball.get_position()
    # env._task.set_obj_start_pos(obj_start_pos) # for reward compute

# change_obj_pose
# adjust_light_intensity
# adjust_light_color
# replace_object
# change_texture
# apply_random_aug

def setup_wandb(checkpoint_path):
    """Attach to the wandb run that a checkpoint belongs to.

    The run id is the fourth-from-last '/'-separated component of
    ``checkpoint_path``.

    NOTE(review): the original comment gave the expected layout as
    "runs/<project>/<run_id>/checkpoints/<step>.pt"; for that layout index -4
    selects ``<project>``, not ``<run_id>``. Confirm against the real
    checkpoint directory structure.

    Args:
        checkpoint_path: Path of the checkpoint file, using '/' separators.

    Returns:
        The imported ``wandb`` module, after ``wandb.init`` has been called.
    """
    import wandb

    path_parts = checkpoint_path.split('/')
    run_id = path_parts[-4]

    init_kwargs = {
        "project": "policy_decorator",  # same project name as training
        "id": run_id,                   # reuse the training run id
        "resume": "allow",              # resume the run if it already exists
    }
    wandb.init(**init_kwargs)
    return wandb

#1.load policy （base policy and res policy）
def load_diffusion_policy_model(checkpoint_path):
    """Load a diffusion-policy checkpoint and return the policy in eval mode on CUDA.

    The checkpoint is a dill-pickled payload containing a hydra config under
    ``'cfg'``; the workspace class named by ``cfg._target_`` is instantiated
    and populated from the payload, and its ``model`` is returned.

    SECURITY NOTE: this function disables HTTPS certificate verification
    process-wide and unpickles the checkpoint with ``dill``. Only use it with
    checkpoints and networks you trust.

    Args:
        checkpoint_path: Path to the ``.ckpt`` file.

    Returns:
        The policy (``workspace.model``), moved to CUDA, in eval mode, with
        ``num_inference_steps = 4`` and ``n_action_steps = 8``.

    Raises:
        ValueError: If ``cfg.name`` does not contain ``'diffusion'``. The
            original code fell through to ``return policy`` with ``policy``
            unbound in that case.
    """
    import ssl
    ssl._create_default_https_context = ssl._create_unverified_context

    import requests
    from requests.packages.urllib3.exceptions import InsecureRequestWarning
    requests.packages.urllib3.disable_warnings(InsecureRequestWarning)

    # Tell huggingface_hub to skip SSL verification as well.
    import os
    os.environ['HF_HUB_DISABLE_SSL_VERIFICATION'] = '1'

    # Close the checkpoint file deterministically instead of leaking the handle.
    with open(checkpoint_path, 'rb') as ckpt_file:
        payload = torch.load(ckpt_file, pickle_module=dill)
    cfg = payload['cfg']

    cls = hydra.utils.get_class(cfg._target_)
    workspace: BaseWorkspace = cls(cfg)
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    if 'diffusion' not in cfg.name:
        raise ValueError(
            f"Unsupported policy type {cfg.name!r}: only diffusion policies can be loaded"
        )

    # The non-EMA weights are used (the EMA variant was disabled by the author).
    policy: BaseImagePolicy = workspace.model

    device = torch.device('cuda')
    policy.eval().to(device)

    # Inference parameters.
    policy.num_inference_steps = 4  # DDIM inference iterations
    policy.n_action_steps = 8
    return policy
#2. load env
def load_environment():
    """Build the simulator environment and its data-collection wrapper.

    Reads ``CONFIG_FILE_PATH``, creates a ``MyEnv`` from its ``"sim"`` section,
    wraps it so that all episodes (not only successes) are recorded to
    ``TEST_OUT_PATH/output_path_name``, and stores the initial position of
    ``ball0`` in the module-level ``obj_start_pos`` and in the task.

    Returns:
        tuple: ``(env, wrapped_env)``.
    """
    global obj_start_pos

    with open(CONFIG_FILE_PATH, "r") as cfg_file:
        full_cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)

    env = MyEnv(configs=full_cfg["sim"])
    hdf5_path = os.path.join(TEST_OUT_PATH, output_path_name)
    wrapped_env = DataCollectionWrapper(env=env, output_path=hdf5_path, only_successes=False)

    # Record the object's start position; the reward computation relies on it.
    target = env.scene.object_registry("name", "ball0")
    obj_start_pos = target.get_position()
    env._task.reset_for_new_task(obj_start_pos)
    return env, wrapped_env

def setup_env_augmentor(default_obj_path):
    """Create the ``EnvAugmentor`` used to restore and randomise the scene.

    Args:
        default_obj_path: Passed to the augmentor as ``default_obj_config_path``.

    Returns:
        EnvAugmentor: Augmentor configured with a fixed default texture image.
    """
    default_texture = "/home/admin01/project/Benchmark_results/assets/texture/texture_d0/20241126-171022.jpg"
    return EnvAugmentor(
        default_obj_config_path=default_obj_path,
        default_texture_path=default_texture,
    )
    

#3.obs process and model inference
def process_image(img):
    """Convert an ``[H, W, C]`` image into a ``[3, 224, 224]`` float tensor.

    Only the first three channels are kept. If the largest value exceeds 1.0
    the image is treated as 0-255 data and divided by 255; otherwise it is
    assumed to be normalised already. The result is resized bilinearly.

    Args:
        img: ``torch.Tensor`` or numpy array with channels on the last axis.

    Returns:
        torch.Tensor: Float tensor of shape ``[3, 224, 224]``.
    """
    # Drop any channels beyond RGB, then make sure we hold a float tensor.
    rgb = img[..., :3]
    if not isinstance(img, torch.Tensor):
        rgb = torch.from_numpy(rgb)

    # [H, W, C] -> [C, H, W]
    chw = rgb.float().permute(2, 0, 1)

    # Scale 0-255 data into [0, 1]; data already within range is left alone.
    peak = chw.max().item()
    if not (peak <= 1.0):
        chw = chw / 255.0

    # interpolate expects a batch dimension; add it and strip it again.
    resized = F.interpolate(
        chw.unsqueeze(0),
        size=(224, 224),
        mode='bilinear',
        align_corners=False,
    )
    return resized.squeeze(0)
def get_obj_state_info(obj):
    """Collect an object's bounding-box pose and velocities.

    Args:
        obj: Object exposing ``get_base_aligned_bbox``, ``get_linear_velocity``
            and ``get_angular_velocity``.

    Returns:
        tuple: ``(state, info)`` where ``info`` maps the four field names to
        their tensors and ``state`` is those tensors concatenated along dim 0
        in the order bbox centre, bbox quaternion, linear velocity, angular
        velocity.
    """
    # The bbox extent and the centre in the desired frame are not needed here.
    center, quat, _extent, _center_desired = obj.get_base_aligned_bbox(
        visual=False
    )

    obj_state_info = {
        "bbox_center_in_world": center,
        "bbox_quat_in_world": quat,
        "linear_velocity": obj.get_linear_velocity(),
        "angular_velocity": obj.get_angular_velocity(),
    }

    # Dict insertion order fixes the layout of the flat state vector.
    obj_state_data = th.cat(list(obj_state_info.values()), dim=0)
    return obj_state_data, obj_state_info

def preprocess_observation(obs,env):
    """Turn a raw observation into the batched dict fed to the base policy.

    No device transfer happens here; tensors stay on whatever device they
    arrive on.

    Fixes relative to the previous version:
      * the stage entry is always a 1-element tensor; previously the reward
        branch produced a 0-dim tensor, which ``th.cat`` refuses to
        concatenate;
      * a missing ``'psi'`` entry now raises a descriptive ``KeyError``
        instead of failing later on an unbound local.

    Args:
        obs: Observation dict, or a sequence whose first element is that dict.
        env: Environment used to look up the state of the object ``ball0``.

    Returns:
        dict: One tensor per key, each with two leading singleton dimensions
        added in front of the original shape.

    Raises:
        KeyError: If the observation has no ``'psi'`` entry.
    """
    obs_dict = obs if isinstance(obs, dict) else obs[0]
    if 'psi' not in obs_dict:
        raise KeyError("observation has no 'psi' entry; cannot build policy input")

    ball = env.scene.object_registry("name", "ball0")
    ball_state_data, _ = get_obj_state_info(ball)

    # The stage flag is deliberately pinned to 0.0: the previous code also
    # overrode it with 0.0 when a reward breakdown was present, with the
    # stage-dependent variant commented out.
    task_stage_data = th.tensor([0.0])

    # Object state followed by the stage flag.
    state_tensor = th.cat([ball_state_data, task_stage_data], dim=0)

    robot_obs = obs_dict['psi']
    proprio = robot_obs['proprio']
    # Fixed slicing of the flat proprioception vector; only the base camera
    # is used (the two arm cameras were disabled by the author).
    current_obs = {
        'base_camera_rgb': process_image(robot_obs['psi:base_camera_rgb:Camera:0']['rgb']),
        'joint_qpos': proprio[:36],
        'joint_qvel': proprio[36:72],
        'gripper_0_qpos': proprio[72:83],
        'gripper_0_qvel': proprio[83:94],
        'eef_0_pos': proprio[94:97],
        'eef_0_quat': proprio[97:101],
        'gripper_1_qpos': proprio[101:112],
        'gripper_1_qvel': proprio[112:123],
        'eef_1_pos': proprio[123:126],
        'eef_1_quat': proprio[126:130],
        'obj_state': state_tensor,
    }

    # Add the two leading singleton dimensions expected downstream.
    return {
        key: torch.stack([value], dim=0).unsqueeze(0)
        for key, value in current_obs.items()
    }


def move_to_initial_position(env):
    """Placeholder hook for moving the robot to its start pose; does nothing.

    The disabled implementation stepped the environment 100 times towards each
    of two waypoints:

        for i in range(100):  # Maximum 100 steps
            action = generate_action(0.3, 0.5, 1.4)
            obs, _, _, _, _ = env.step(action)
        for i in range(100):  # Maximum 100 steps
            action = generate_action(0.4, 0.2, 1.15)
            obs, _, _, _, _ = env.step(action)

    Args:
        env: Environment instance (currently unused).
    """
    return None

def convert_to_serializable(obj):
    """Recursively convert tensors and numpy values into plain Python data.

    ``torch.Tensor`` and ``numpy.ndarray`` values become (nested) lists, and
    numpy scalars become Python scalars. Dicts, lists and tuples are
    converted element-wise, with tuples staying tuples. Any other object is
    returned unchanged.

    numpy values were previously passed through untouched, which made the
    result fail in ``json.dump``.

    Args:
        obj: Arbitrary, possibly nested, value.

    Returns:
        The converted structure.
    """
    if isinstance(obj, (torch.Tensor, np.ndarray)):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # numpy scalar (np.float32, np.int64, ...) -> Python scalar
        return obj.item()
    if isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_to_serializable(item) for item in obj)
    return obj
def get_domain_info():
    """Load the three YAML configs that describe the evaluation domain.

    Returns:
        dict: ``'env_augmentor'`` (contents of ``ENV_AUGMENTOR_CONFIG_PATH``),
        ``'obj_config'`` (contents of the file named by that config's
        ``'obj_config_path'`` entry) and ``'env_config'`` (contents of
        ``CONFIG_FILE_PATH``).
    """
    def _read_yaml(path):
        # Parse one YAML file with the full loader.
        with open(path, "r") as handle:
            return yaml.load(handle, Loader=yaml.FullLoader)

    env_aug_config = _read_yaml(ENV_AUGMENTOR_CONFIG_PATH)
    domain_info = {'env_augmentor': env_aug_config}

    # The object config location is stored inside the augmentor config.
    domain_info['obj_config'] = _read_yaml(env_aug_config['obj_config_path'])

    domain_info['env_config'] = _read_yaml(CONFIG_FILE_PATH)
    return domain_info
def extract_proprio_obj_state(env, obs):
    """Build the flat state vector: proprioception, object state, stage flag.

    Args:
        env: Environment used to look up the object ``ball0``.
        obs: Observation dict, or a sequence whose first element is that dict.

    Returns:
        Tensor: The first 130 proprioception entries, followed by the state of
        ``ball0`` and a final stage entry. The stage entry is 1.0 only when
        ``obs["reward"]["reward_breakdown"]["stage"]`` exists and is truthy,
        otherwise 0.0.
    """
    obs_dict = obs if isinstance(obs, dict) else obs[0]
    proprio_full = obs_dict['psi']['proprio']

    target = env.scene.object_registry("name", "ball0")
    obj_state_data, _ = get_obj_state_info(target)

    # Stage flag defaults to 0.0 when the observation carries no reward info.
    stage_value = 0.0
    if "reward" in obs_dict and "reward_breakdown" in obs_dict["reward"]:
        breakdown = obs_dict["reward"]["reward_breakdown"]
        if "stage" in breakdown and breakdown["stage"]:
            stage_value = 1.0
    task_stage_data = th.tensor([stage_value])

    # (start, end) ranges of the proprioception vector, in output order.
    proprio_layout = (
        (0, 36),     # joint_qpos
        (36, 72),    # joint_qvel
        (72, 83),    # gripper_0_qpos
        (83, 94),    # gripper_0_qvel
        (94, 97),    # eef_0_pos
        (97, 101),   # eef_0_quat
        (101, 112),  # gripper_1_qpos
        (112, 123),  # gripper_1_qvel
        (123, 126),  # eef_1_pos
        (126, 130),  # eef_1_quat
    )
    segments = [proprio_full[start:end] for start, end in proprio_layout]
    segments.append(th.cat([obj_state_data, task_stage_data], dim=0))
    return th.cat(segments)

def complete_loop(wrapped_env):
    """Save the data of the finished collection loop and rotate the file name.

    The name for the next file is derived from the time at which this function
    is entered, i.e. before the 0.5 s pause and the save.

    Args:
        wrapped_env: Data-collection wrapper providing ``save_data()``.
    """
    next_filename = "data_{}.hdf5".format(datetime.now().strftime("%Y%m%d_%H%M%S"))

    print(f"Completing loop. Current data saved in {current_filename[0]}")
    time.sleep(0.5)
    wrapped_env.save_data()

    # Later resets write into the new file.
    current_filename[0] = next_filename

def get_domain_name(config_path):
    """Return the domain name encoded at the end of a config file name.

    The directory part and the extension are dropped, and whatever follows the
    last underscore is returned, e.g. ``'env_aug_d2-d1.yaml'`` -> ``'d2-d1'``.
    A name without underscores is returned whole.

    Args:
        config_path: Path to the config file.

    Returns:
        str: The trailing underscore-separated segment of the file stem.
    """
    stem, _extension = os.path.splitext(os.path.basename(config_path))
    return stem.rsplit('_', 1)[-1]

def _check_failure(env):
    """Check whether the current episode should be treated as failed.

    The episode counts as failed when the two hands have closed in on each
    other without the object having been lifted: the y-distance between the
    two middle-finger links is below half of the object's bounding-box
    y-extent while the object has gained less than 10% of the 0.15 target
    lift height over the module-level ``obj_start_pos``.

    Additional failure conditions can be added here.

    Args:
        env: Environment instance providing ``robots[0]`` and the scene.

    Returns:
        bool: True if the failure condition holds. Only a flag is returned;
        no failure reason is reported.
    """
    robot = env.robots[0]
    ball = env.scene.object_registry("name", "ball0")
    _, _, bbox_extent_in_base_frame, _ = ball.get_base_aligned_bbox(visual=False)
    obj_pos = torch.as_tensor(ball.get_position())

    # Only the middle-finger links take part in the check; the other finger
    # positions that used to be queried here were never used.
    left_hand_middle_pos = robot._links["hand1_link_3_2"].get_position()
    right_hand_middle_pos = robot._links["hand2_link_3_2"].get_position()
    distance_between_hands = abs(left_hand_middle_pos[1] - right_hand_middle_pos[1])

    # Fraction of the target lift height reached so far.
    target_lift_height = 0.15
    r_lift_ratio = (obj_pos[2] - obj_start_pos[2]) / target_lift_height

    hands_too_close = distance_between_hands < bbox_extent_in_base_frame[1] / 2
    return bool(hands_too_close and r_lift_ratio < 0.1)
def is_stage_grasp_done(env):
    """Tell whether both hands have reached their grasp points on ``ball0``.

    For each hand the mean Euclidean distance from three finger links
    (labelled index, middle and ring in the original comments) to that hand's
    grasp point is computed. A grasp point is the object position shifted
    along y by half the bounding-box y-extent plus 0.015 and lowered by 0.015
    in z. The grasp stage is done when both means are below 0.1.

    The previous version also looked up an object named ``ball1`` and read its
    height without using the value, which raised when the scene had no such
    object; that lookup and the unused thumb queries were removed.

    Args:
        env: Environment instance providing ``robots[0]`` and the scene.

    Returns:
        bool: True when both hands are within the distance threshold.
    """
    robot = env.robots[0]
    obj = env.scene.object_registry("name", "ball0")
    obj_pos = torch.as_tensor(obj.get_position())
    _, _, bbox_extent_in_base_frame, _ = obj.get_base_aligned_bbox(visual=False)
    half_width = bbox_extent_in_base_frame[1] / 2

    # clone() keeps the simulator's own position data untouched.
    grasp_point_left = obj_pos.clone()
    grasp_point_left[1] = grasp_point_left[1] + half_width + 0.015
    grasp_point_left[2] = grasp_point_left[2] - 0.015
    grasp_point_right = obj_pos.clone()
    grasp_point_right[1] = grasp_point_right[1] - half_width - 0.015
    grasp_point_right[2] = grasp_point_right[2] - 0.015

    def _mean_finger_distance(hand, grasp_point):
        # Mean distance from the hand's three finger links to its grasp point.
        distances = [
            torch.norm(
                grasp_point - robot._links[f"{hand}_link_{finger}_2"].get_position(),
                p=2,
                dim=-1,
            )
            for finger in (2, 3, 4)
        ]
        return sum(distances) / len(distances)

    right_avg_distance = _mean_finger_distance("hand2", grasp_point_right)
    left_avg_distance = _mean_finger_distance("hand1", grasp_point_left)

    print(f"左手平均距离: {left_avg_distance:.4f}, 右手平均距离: {right_avg_distance:.4f}")
    return bool(left_avg_distance < 0.1 and right_avg_distance < 0.1)
        # 计算抓取奖励

def convert_24d_to_34d_action(action, robot):
    """Expand a 24-D action into the 34-D action used by the robot.

    Each 6-D hand command is spread over an 11-D hand command: the six inputs
    are copied to slots 0, 1, 3, 5, 7 and 9, and each of the slots 2, 4, 6, 8
    and 10 is derived from the slot before it by normalising that value with
    the limits of one joint and rescaling it to the limits of another joint.

    NOTE(review): the finger labels in the original comments were not
    consistent between the two hands, so the joint indices below are
    reproduced verbatim without finger names. Confirm against the robot's
    joint ordering.

    Args:
        action (torch.Tensor): 24-D action laid out as
            [left arm (6), right arm (6), left hand (6), right hand (6)].
        robot: Robot object providing ``joint_upper_limits`` and
            ``joint_lower_limits``.

    Returns:
        torch.Tensor: 34-D action laid out as
        [left arm (6), right arm (6), left hand (11), right hand (11)].
    """
    arm_left_command, arm_right_command = action[0:6], action[6:12]
    compact_hands = (action[12:18], action[18:24])  # (left, right)

    upper_limit = robot.joint_upper_limits
    lower_limit = robot.joint_lower_limits

    # Output slots that receive the six compact hand inputs unchanged.
    direct_slots = (0, 1, 3, 5, 7, 9)
    # Per hand: (source joint, target joint) limit indices for the derived
    # slots 2, 4, 6, 8 and 10, in that order.
    coupled_joints = (
        ((24, 34), (15, 25), (16, 26), (17, 27), (18, 28)),  # left hand
        ((20, 30), (21, 31), (22, 32), (23, 33), (29, 35)),  # right hand
    )

    expanded_hands = []
    for compact, joint_pairs in zip(compact_hands, coupled_joints):
        full = th.zeros(11)
        for input_idx, slot in enumerate(direct_slots):
            full[slot] = compact[input_idx]
        for slot, (src_joint, dst_joint) in zip(direct_slots[1:], joint_pairs):
            # Position of the driving value within the source joint's range ...
            norm_temp = (full[slot] - lower_limit[src_joint]) / (upper_limit[src_joint] - lower_limit[src_joint])
            # ... mapped onto the target joint's range.
            full[slot + 1] = norm_temp * (upper_limit[dst_joint] - lower_limit[dst_joint]) + lower_limit[dst_joint]
        expanded_hands.append(full)

    hand_left_command_final, hand_right_command_final = expanded_hands
    return th.cat((arm_left_command, arm_right_command, hand_left_command_final, hand_right_command_final), 0)
       
def evaluate(wrapped_env, base_policy, default_obj_paths, device="cuda:0", start_idx=0, res_policy=None):
    """Roll out the base (and optional residual) policy over the augmentation configs.

    For every path in ``default_obj_paths`` an environment augmentor is built and
    the augmentation configs ``start_idx .. len(domain_cfg['aug']) - 1`` are each
    run for one episode. After every episode the text log and the JSON results
    file are updated, so an interrupted run can be resumed through ``start_idx``.

    Args:
        wrapped_env: Environment wrapper. The code uses its ``env``, ``step``,
            ``set_data_mask``, ``current_obs``, ``step_count`` and
            ``current_traj_history`` members.
        base_policy: Policy whose ``predict_action(obs_dict)['action']`` provides
            the base action sequence; only the first action of it is executed.
        default_obj_paths: Iterable of default-object configuration paths.
        device: Torch device for the observation tensors and the residual input.
        start_idx: Index of the first augmentation config to run. When it is
            greater than zero, the most recent ``eval_*`` JSON file is reused if
            its ``num_configs_completed`` equals ``start_idx``.
        res_policy: Optional residual policy exposing ``get_eval_action``. When
            ``None``, the base action is executed unchanged.

    Returns:
        dict: Summary statistics (``success_rate``, ``mean_reward``, ...), the raw
        per-episode lists, ``episode_log`` and ``num_configs``.
    """
    # Lazily load the domain (augmentation) configuration.
    global domain_cfg
    if domain_cfg is None:
        with open(ENV_AUGMENTOR_CONFIG_PATH, "r") as f:
            domain_cfg = yaml.safe_load(f)

    # Number of augmentation configurations to evaluate.
    total_aug_configs = len(domain_cfg['aug'])
    print(f"总共有 {total_aug_configs} 个环境增强配置需要评估")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Make sure the output directories exist.
    os.makedirs(log_save_path, exist_ok=True)
    os.makedirs(json_save_path, exist_ok=True)

    # Output files for this run (replaced below when resuming an earlier run).
    json_filename = f"{json_save_path}/eval_{timestamp}.json"
    log_filename = f"{log_save_path}/eval_{timestamp}.log"

    # Per-episode records.
    episode_rewards = []
    episode_successes = []
    episode_lengths = []
    episode_times = []
    episode_log = []
    res_norm_ratios = []
    res_norms = []
    base_norms = []

    # When starting part-way through, try to pick up the previous results.
    if start_idx > 0:
        print(f"从配置索引 {start_idx} 开始继续评估")

        # The timestamped names sort chronologically, so the last one is the newest.
        existing_jsons = sorted([f for f in os.listdir(json_save_path) if f.startswith("eval_")])
        if existing_jsons:
            latest_json = os.path.join(json_save_path, existing_jsons[-1])
            print(f"尝试加载已有结果: {latest_json}")

            try:
                with open(latest_json, "r") as f:
                    existing_results = json.load(f)

                # Only resume if the stored progress matches the requested start index.
                if existing_results.get('num_configs_completed', 0) == start_idx:
                    episode_rewards = existing_results.get('episode_rewards', [])
                    episode_successes = existing_results.get('episode_successes', [])
                    episode_lengths = existing_results.get('episode_lengths', [])
                    episode_times = existing_results.get('episode_times', [])
                    episode_log = existing_results.get('episode_log', [])

                    # Keep writing to the existing files instead of creating new ones.
                    json_filename = latest_json
                    log_filename = os.path.join(log_save_path, os.path.basename(latest_json).replace('.json', '.log'))

                    print(f"成功加载已有结果，已完成 {len(episode_successes)} 个配置的评估")
                else:
                    print(f"警告: 已有结果完成的配置数 ({existing_results.get('num_configs_completed', 0)}) 与指定的开始索引 ({start_idx}) 不匹配")
                    print("将创建新的结果文件")
            except Exception as e:
                # Best effort: any failure to resume falls back to a fresh results file.
                print(f"加载已有结果时出错: {e}")
                print("将创建新的结果文件")

    def _snapshot(num_completed):
        """Summarise every episode recorded so far (zeros when there are none)."""
        def _mean(values):
            return float(np.mean(values)) if len(values) > 0 else 0.0

        def _std(values):
            return float(np.std(values)) if len(values) > 0 else 0.0

        return {
            'num_configs_completed': num_completed,
            'total_configs': total_aug_configs,
            'success_rate': _mean(episode_successes) * 100,
            'mean_reward': _mean(episode_rewards),
            'std_reward': _std(episode_rewards),
            'mean_length': _mean(episode_lengths),
            'mean_episode_time': _mean(episode_times),
            'std_episode_time': _std(episode_times),
            'episode_log': episode_log,
            # Raw data, stored so that a later run can resume from this file.
            'episode_rewards': episode_rewards,
            'episode_successes': episode_successes,
            'episode_lengths': episode_lengths,
            'episode_times': episode_times
        }

    # Fresh evaluation (or failed resume): write the log header.
    if len(episode_log) == 0:
        eval_domain = get_domain_name(ENV_AUGMENTOR_CONFIG_PATH)
        with open(log_filename, "w") as f:
            f.write(f"评估领域: {eval_domain}\n")
            f.write(f"配置文件: {ENV_AUGMENTOR_CONFIG_PATH}\n")
            f.write(f"IL checkpoint: {base_policy_ckpt}\n")
            f.write(f"开始索引: {start_idx}\n")
            f.write("评估结果将在运行过程中更新...\n\n")
            f.write("配置执行日志:\n")
    else:
        # Resuming: only append a separator to the existing log.
        with open(log_filename, "a") as f:
            f.write(f"\n--- 从索引 {start_idx} 继续评估 ---\n")

    # Stays None until at least one episode has finished in this call.
    current_results = None

    complete_loop(wrapped_env)
    for i, default_obj_path in enumerate(default_obj_paths):
        env_augmentor = setup_env_augmentor(default_obj_path)
        print(f"评估第 {i+1} 个默认物体配置")
        for idx in range(start_idx, total_aug_configs):
            episode_start_time = time.time()
            print(f"\n评估配置 {idx+1+i*total_aug_configs}/{total_aug_configs * len(default_obj_paths)}")

            episode_reward = 0
            episode_length = 0
            episode_success = False

            # Reset the environment and apply augmentation config `idx`.
            env_reset(wrapped_env, env_augmentor, idx)

            obs = wrapped_env.env.get_obs()
            obs_dict = obs if isinstance(obs, dict) else obs[0]
            wrapped_env.current_obs = deepcopy(obs_dict)
            done = False
            fail = False
            while not done and not fail:
                wrapped_env.set_data_mask(1)
                obs_dict = preprocess_observation(obs, wrapped_env.env)
                obs_dict = {k: torch.as_tensor(v, device=device) if isinstance(v, (np.ndarray, torch.Tensor)) else v
                            for k, v in obs_dict.items()}

                with torch.no_grad():
                    # Execute only the first action of the predicted sequence.
                    base_act_seq = base_policy.predict_action(obs_dict)['action']
                    base_act_seq = base_act_seq.squeeze(0)[0].cpu()
                    base_act_seq = convert_24d_to_34d_action(base_act_seq, wrapped_env.env.scene.robots[0])

                if res_policy is not None:
                    with torch.no_grad():
                        proprio_obj_state = extract_proprio_obj_state(wrapped_env.env, obs)
                        proprio_obj_state_tensor = torch.FloatTensor(proprio_obj_state).unsqueeze(0).to(device)
                        res_action = res_policy.get_eval_action(proprio_obj_state_tensor)
                        res_action = res_action.detach().cpu()
                        if mix:
                            # Stage-dependent residual scale: switch once the grasp stage is done.
                            done_flag = is_stage_grasp_done(wrapped_env.env)
                            if done_flag:
                                print("done_flag:", done_flag)
                                scaled_res_action = NORMAL_RES_CALE * res_action
                            else:
                                scaled_res_action = RES_SCALE * res_action
                        else:
                            scaled_res_action = RES_SCALE * res_action
                        # Norms of both action components, kept for diagnostics.
                        res_norm = torch.norm(scaled_res_action).item()
                        base_norm = torch.norm(base_act_seq).item()
                        norm_ratio = res_norm / (base_norm + 1e-8)

                        # Drop the batch dimension so the residual matches the base action.
                        scaled_res_action = scaled_res_action.squeeze(0)

                        # Final action = base action + scaled residual.
                        combined_action = base_act_seq + scaled_res_action
                        print("base_act_seq:", base_act_seq)
                        print("scaled_res_action:", scaled_res_action)
                else:
                    # No residual policy: run the base action as-is. Without this
                    # branch the names used below would be unbound.
                    combined_action = base_act_seq
                    res_norm = 0.0
                    base_norm = torch.norm(base_act_seq).item()
                    norm_ratio = 0.0

                obs, reward, termination, truncation, info = wrapped_env.step(combined_action)

                # Record the action norms of this step.
                res_norms.append(res_norm)
                res_norm_ratios.append(norm_ratio)
                base_norms.append(base_norm)

                reward = reward - 1.0
                episode_reward += reward
                episode_length += 1
                episode_success = termination

                if episode_success:
                    print("成功")
                done = termination or truncation
                fail = _check_failure(wrapped_env)
                if fail:
                    # A detected failure is handled like a truncated episode below.
                    truncation = True
            episode_time = time.time() - episode_start_time
            episode_times.append(episode_time)
            episode_rewards.append(episode_reward)
            episode_successes.append(episode_success)
            episode_lengths.append(episode_length)

            # Record the augmentation that was applied together with the outcome.
            result_str = "success" if episode_success else "failure"
            episode_log.append((env_augmentor.log, result_str))

            if episode_success:
                complete_loop(wrapped_env)
            elif truncation:
                # Discard the failed trajectory and roll the recording over to a new file.
                wrapped_env.step_count -= len(wrapped_env.current_traj_history)
                wrapped_env.current_traj_history = []
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                rm_path = os.path.join(TEST_OUT_PATH, current_filename[0])
                if current_filename[0] != output_path_name and os.path.exists(rm_path):
                    os.remove(rm_path)
                current_filename[0] = f"data_{timestamp}.hdf5"

            # Append this episode to the log file right away.
            with open(log_filename, "a") as f:
                f.write(f"{idx+1}. {str(env_augmentor.log)} => {result_str}\n")

            # Running summary over everything evaluated so far.
            current_results = _snapshot(idx + 1)

            # Rewrite the JSON file after every episode.
            with open(json_filename, "w") as f:
                json.dump(convert_to_serializable(current_results), f, indent=4)

            print(f"当前成功率: {current_results['success_rate']:.1f}% ({sum(episode_successes)}/{len(episode_successes)})")

    if current_results is None:
        # Nothing ran in this call (no object configs, or start_idx is already past
        # the last augmentation config): summarise whatever was loaded instead of
        # failing on an unbound variable.
        current_results = _snapshot(len(episode_successes))

    # Final statistics once every configuration has been processed.
    final_results = current_results
    final_results['num_configs'] = total_aug_configs

    # Append the summary to the log file.
    with open(log_filename, "a") as f:
        f.write("\n\n最终评估结果:\n")
        f.write(f"配置数量: {final_results['num_configs']}\n")
        f.write(f"成功率: {final_results['success_rate']:.1f}%\n")
        f.write(f"平均奖励: {final_results['mean_reward']:.2f} ± {final_results['std_reward']:.2f}\n")
        f.write(f"平均步数: {final_results['mean_length']:.1f}\n")
        f.write(f"平均执行时间: {final_results['mean_episode_time']:.2f} ± {final_results['std_episode_time']:.2f} 秒\n")

        # Success rate grouped by configuration type.
        f.write("\n按配置类型分组的成功率:\n")
        config_type_stats = defaultdict(lambda: {'success': 0, 'total': 0})

        for log_entry, result in final_results['episode_log']:
            # The first element/key of the log entry is used as the configuration type.
            config_type = next(iter(log_entry), None)

            if config_type:
                config_type_stats[config_type]['total'] += 1
                if result == "success":
                    config_type_stats[config_type]['success'] += 1

        for config_type, stats in config_type_stats.items():
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            f.write(f"  • {config_type}: {success_rate:.1f}% ({stats['success']}/{stats['total']})\n")

    # Write the final JSON file.
    with open(json_filename, "w") as f:
        json.dump(convert_to_serializable(final_results), f, indent=4)

    return final_results

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight is filled with an orthogonal matrix scaled by ``std`` and every
    bias entry is set to ``bias_const``.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer

def test_env_augmentor():
    """Manual smoke test: reset the environment ten times through the augmentor.

    Each iteration resets the environment and then sleeps for 15 seconds before
    the next reset.
    """
    # NOTE(review): evaluate() calls setup_env_augmentor(default_obj_path); this
    # call relies on that argument being optional - confirm.
    env_augmentor = setup_env_augmentor()
    # load_environment() is unpacked as (env, wrapped_env) in main(), and
    # evaluate() hands the wrapped environment to env_reset, so do the same here
    # instead of passing the whole return value.
    _, wrapped_env = load_environment()
    for i in range(10):
        env_reset(wrapped_env, env_augmentor)
        print(f"Episode {i+1} start")
        time.sleep(15)
        print(f"Episode {i+1} completed")
def get_action_space_boundaries(env):
    """Return ``(action_scale, action_bias)`` for the 34-dim bimanual action.

    Action layout: [left-arm EEF (6), right-arm EEF (6), left-hand joints (11),
    right-hand joints (11)]. EEF position entries are bounded by +/-1.0 and EEF
    orientation entries by +/-1.5; hand entries take the joint limits of
    ``env.scene.robots[0]``. The returned float32 tensors map an action in
    [-1, 1] onto those bounds via ``action * action_scale + action_bias``.
    """
    robot = env.scene.robots[0]
    upper_limits = robot.joint_upper_limits
    lower_limits = robot.joint_lower_limits

    action_dim = 34
    low = np.zeros(action_dim)
    high = np.zeros(action_dim)

    # Both arms share the same EEF bounds: 3 position entries, 3 orientation entries.
    eef_bound = np.array([1.0, 1.0, 1.0, 1.5, 1.5, 1.5])
    for arm_start in (0, 6):
        low[arm_start:arm_start + 6] = -eef_bound
        high[arm_start:arm_start + 6] = eef_bound

    # (first action slot, robot joint indices) for the left and the right hand.
    hand_layout = (
        (12, (14, 15, 16, 17, 18, 24, 25, 26, 27, 28, 34)),
        (23, (19, 20, 21, 22, 23, 29, 30, 31, 32, 33, 35)),
    )
    for first_slot, joint_indices in hand_layout:
        for slot, joint_idx in enumerate(joint_indices, start=first_slot):
            low[slot] = lower_limits[joint_idx]
            high[slot] = upper_limits[joint_idx]

    low_t = torch.tensor(low, dtype=torch.float32)
    high_t = torch.tensor(high, dtype=torch.float32)

    # Half-range and midpoint of every action dimension.
    half_range = (high_t - low_t) / 2.0
    midpoint = (high_t + low_t) / 2.0
    return half_range, midpoint
# Range into which Actor.forward maps the predicted log standard deviation.
LOG_STD_MAX = 2
LOG_STD_MIN = -20


class Actor(nn.Module):
    """Gaussian actor with tanh squashing, used as the residual policy.

    A three-layer MLP backbone feeds two linear heads (mean and log-std). When
    ``use_action_scaling`` is set, squashed actions are mapped onto the bounds
    returned by ``get_action_space_boundaries``.
    """

    def __init__(self, env, obs_dim, action_dim):
        """Build the network.

        Args:
            env: Passed to ``get_action_space_boundaries`` to obtain the action
                scale and bias buffers.
            obs_dim: Size of the flat observation vector.
            action_dim: Size of the action vector.
        """
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(obs_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        self.fc_mean = layer_init(nn.Linear(256, action_dim), std=0.01)
        self.fc_logstd = layer_init(nn.Linear(256, action_dim), std=0.01)

        # Whether squashed actions are mapped onto the real action bounds.
        self.use_action_scaling = True

        if self.use_action_scaling:
            action_scale, action_bias = get_action_space_boundaries(env)
            # Buffers follow the module across devices and are stored in the state dict.
            self.register_buffer("action_scale", action_scale)
            self.register_buffer("action_bias", action_bias)

    def _to_action_space(self, squashed):
        """Map a tanh-squashed tensor onto the action bounds when scaling is enabled."""
        if self.use_action_scaling:
            return squashed * self.action_scale + self.action_bias
        return squashed

    def forward(self, x):
        """Return ``(mean, log_std)``; ``log_std`` lies in [LOG_STD_MIN, LOG_STD_MAX]."""
        x = self.backbone(x)
        mean = self.fc_mean(x)
        log_std = self.fc_logstd(x)
        log_std = torch.tanh(log_std)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats
        return mean, log_std

    def get_action(self, x):
        """Sample an action with the reparameterisation trick.

        Returns:
            tuple: ``(action, log_prob, mean)`` where ``log_prob`` is summed over
            dimension 1 (kept as a size-1 dimension) and ``mean`` is the squashed
            (and, if enabled, scaled) distribution mean.
        """
        mean, log_std = self(x)
        std = log_std.exp()
        normal = torch.distributions.Normal(mean, std)
        x_t = normal.rsample()  # mean + std * N(0, 1)
        y_t = torch.tanh(x_t)

        action = self._to_action_space(y_t)

        # Change-of-variables correction for the tanh (and optional scaling).
        jacobian = 1 - y_t.pow(2)
        if self.use_action_scaling:
            jacobian = self.action_scale * jacobian
        # The epsilon keeps the log finite when tanh saturates (y_t == +/-1) or
        # when a dimension has zero scale; it is applied in both modes.
        log_prob = normal.log_prob(x_t) - torch.log(jacobian + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)

        mean = self._to_action_space(torch.tanh(mean))
        return action, log_prob, mean

    def get_eval_action(self, x):
        """Return the deterministic action (squashed mean) used for evaluation."""
        x = self.backbone(x)
        mean = torch.tanh(self.fc_mean(x))
        return self._to_action_space(mean)

def get_default_obj_paths_from_directory(directory_path):
    """Collect the YAML configuration files found directly inside a directory.

    Args:
        directory_path: Directory that holds the configuration files.

    Returns:
        list: Full paths of all regular ``.yaml`` / ``.yml`` files, sorted by file
        name. Empty (with a warning printed) when ``directory_path`` is not an
        existing directory.
    """
    # isdir rather than exists: a path that points at a regular file would
    # otherwise make os.listdir raise NotADirectoryError.
    if not os.path.isdir(directory_path):
        print(f"警告: 指定的目录 {directory_path} 不存在!")
        return []

    candidates = (
        os.path.join(directory_path, filename)
        for filename in sorted(os.listdir(directory_path))
        if filename.endswith(('.yaml', '.yml'))
    )
    # Ignore sub-directories whose name happens to end in .yaml / .yml.
    yaml_files = [path for path in candidates if os.path.isfile(path)]

    print(f"从 {directory_path} 中找到 {len(yaml_files)} 个YAML配置文件")
    return yaml_files

def main():
    """Command-line entry point: load both policies and the environment, then evaluate.

    Options:
        --start_idx: Augmentation-config index to start from (default 0).
        --config_dir: Directory with the default-object YAML files
            (default ``DEFAULT_OBJ_CONFIG_DIR``).
    """
    parser = argparse.ArgumentParser(description='评估IL模型')
    parser.add_argument('--start_idx', type=int, default=0, help='从指定索引开始评估')
    parser.add_argument('--config_dir', type=str,
                        default=DEFAULT_OBJ_CONFIG_DIR,
                        help='包含默认对象配置文件的目录')
    args = parser.parse_args()

    # Use the GPU when available, otherwise fall back to the CPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print(f"使用设备: {device}")

    # Load the base (diffusion) policy.
    print(f"加载策略模型: {base_policy_ckpt}")
    base_policy = load_diffusion_policy_model(base_policy_ckpt)
    base_policy.to(device)

    # Create the environment.
    print("创建环境...")
    env, wrapped_env = load_environment()

    # Load the residual policy. map_location lets a checkpoint saved on a GPU be
    # loaded on a CPU-only host, matching the device fallback above.
    res_checkpoint = torch.load(res_policy_path, map_location=device)
    # Input size: joint_qpos + joint_qvel + gripper_0_qpos + gripper_0_qvel +
    # eef_0_pos + eef_0_quat + gripper_1_qpos + gripper_1_qvel + eef_1_pos +
    # eef_1_quat + obj_state, plus one extra scalar (TODO confirm what it encodes).
    proprio_obj_state_dim = (36 + 36 + 11 + 11 + 3 + 4 + 11 + 11 + 3 + 4 + 13 + 1)
    action_dim = 34  # action dimension of the robot
    # NOTE(review): Actor reads `.scene.robots[0]` from its first argument while
    # evaluate() reaches the scene through `wrapped_env.env` - confirm that the
    # wrapper forwards `.scene`, or pass `env` here.
    res_policy = Actor(wrapped_env, proprio_obj_state_dim, action_dim)
    res_policy.load_state_dict(res_checkpoint['res_actor'])
    res_policy.to(device)
    res_policy.eval()

    # Load the environment-augmentation configuration.
    global domain_cfg
    with open(ENV_AUGMENTOR_CONFIG_PATH, "r") as f:
        domain_cfg = yaml.safe_load(f)

    # Collect every default-object configuration file in the directory.
    default_obj_paths = get_default_obj_paths_from_directory(args.config_dir)

    if not default_obj_paths:
        print(f"错误: 在 {args.config_dir} 中未找到任何YAML配置文件!")
        return

    # Run the evaluation, starting at the requested index.
    print(f"开始评估，起始索引: {args.start_idx}...")
    print(f"使用配置文件目录: {args.config_dir}")
    print(f"找到 {len(default_obj_paths)} 个配置文件")

    results = evaluate(wrapped_env, base_policy, default_obj_paths, device, start_idx=args.start_idx, res_policy=res_policy)

    print(f"评估完成，成功率: {results['success_rate']:.1f}%")
    print(f"结果已保存到 {log_save_path} 和 {json_save_path} 目录")


if __name__ == "__main__":
    main()